#' ---
#' title: "Analyze final 2021: Learning from Polls PID"
#' author: "Lukas F. Stoetzer (changes: Richard Traunmueller)"
#' date: "October 2021"
#' ---

# Libraries
  library(tidyverse)
  library(xtable)
  library(MASS)

# Seed
  # Fixed RNG seed so the MASS::mvrnorm() draws further down are reproducible.
  set.seed(36523923)

# Source Code to estimate models
  # Helper scripts; presumably they define est_mod_time() and est_mod(), which
  # are called below but not defined in this file.
  invisible(lapply(c("FUN_ElicitedBeliefsMLE.R", "FUN_Estimation.R"), source))

  # MASS is attached after the tidyverse and masks dplyr::select(); restore it.
  select <- dplyr::select
  
# Prepare Data =========

  # Load the wide and long versions of the experimental data; the first CSV
  # column holds the row names.
  dat <- read.csv("data_exp_partisian.csv", row.names = 1)
  dat_long <- read.csv("data_exp_partisian_long.csv", row.names = 1)

  # Descriptive Stats: 0/1 indicators for the summary and balance tables.
  # Note: polint is overwritten in place with its dummy version.
  dat$university <- ifelse(dat$educ == "University degree", 1, 0)
  dat$polint <- ifelse(dat$polint == "Very interested", 1, 0)

  # Party-ID dummies ("Independet" is the spelling used in the data).
  dat$pid3REP <- ifelse(dat$pid3 == "Republican", 1, 0)
  dat$pid3IND <- ifelse(dat$pid3 == "Independet", 1, 0)
  # Fix: the Democrat dummy previously tested pid3 == "Independet", so it
  # duplicated pid3IND and mislabelled Independents as Democrats.
  dat$pid3DEM <- ifelse(dat$pid3 == "Democrat", 1, 0)

  # Summary Statistics
  # One row per variable (in selection order) with min / median / max / mean /
  # sd, ignoring NAs. na.rm is set inside each function because forwarding
  # extra arguments through across()'s `...` is deprecated since dplyr 1.1.0.
  tab_sum <- dat %>%
    dplyr::select(female, age, university, polint, pid3REP, pid3DEM) %>%
    summarise(across(everything(),
                     .fns = list(mean   = ~ mean(.x, na.rm = TRUE),
                                 min    = ~ min(.x, na.rm = TRUE),
                                 median = ~ median(.x, na.rm = TRUE),
                                 max    = ~ max(.x, na.rm = TRUE),
                                 sd     = ~ sd(.x, na.rm = TRUE)))) %>%
    # Columns are named "<var>_<stat>"; split them into var / stat.
    pivot_longer(everything(),
                 names_to = c("var", "stat"),
                 names_sep = "_",
                 values_to = "val") %>%
    pivot_wider(names_from = stat, values_from = val) %>%
    dplyr::select(var, min, median, max, mean, sd) # reorder columns

  xtable(tab_sum)
  
  # Balance
  # Mean respondent characteristics (and group size) per treatment arm, with
  # the "Z" arm excluded, printed as a LaTeX table.
  balance.test.1 <- dat %>%
    filter(cand != "Z") %>%
    group_by(cand) %>%
    summarize(
      N      = n(),
      female = mean(female, na.rm = TRUE),
      age    = mean(age, na.rm = TRUE),
      educ   = mean(university, na.rm = TRUE),
      polint = mean(polint, na.rm = TRUE),
      pidDEM = mean(pid3DEM, na.rm = TRUE),
      pidREP = mean(pid3REP, na.rm = TRUE)
    )

  xtable(balance.test.1)


  # Add Poll
  # Drop the Zick-Zack condition ("Z"); attach the poll value shown at time
  # points 1-3 (any other time gets NaN); copy pu into pm; and turn cand and
  # pid3 into labelled factors.
  dat_long <- dat_long %>%
    filter(cand != "Z") %>%
    mutate(
      pollY = case_when(
        time == 1 ~ 51,
        time == 2 ~ 54,
        time == 3 ~ 58,
        TRUE      ~ NaN
      ),
      pm = pu,
      cand = factor(cand,
                    levels = c("D", "R"),
                    labels = c("Democrat Cand. Winning",
                               "Republican Cand. Winning")),
      pid3 = factor(pid3,
                    levels = c("Democrat", "Independet", "Republican"),
                    labels = paste(c("Democrat", "Independet", "Republican"),
                                   "Resp."))
    )
                
  
  
# Estimate Time separate Model ============

  # Fit est_mod_time() separately for every source x cand x pid3 x time cell,
  # using only rows with complete ex, lo, up, pl and pm.
  res_time_mod <- dat_long %>%
    drop_na(ex, lo, up, pl, pm) %>%
    group_by(source, cand, pid3, time) %>%
    do(est_mod_time(.))

  # Log-likelihood per source x cand x pid3, summed over time points;
  # `par` records 4 x 3 parameters per group for this model.
  log_lik_m1 <- res_time_mod %>%
    group_by(source, cand, pid3) %>%
    summarise(log_lik = sum(log_lik)) %>%
    mutate(model = "separate_each_time",
           par   = 4 * 3)

  # Drop the log-likelihood column; only the estimates are used below.
  res_time_mod <- dplyr::select(res_time_mod, -log_lik)
  
  # Plot Results independent model (Intervals)

  # For each fitted Beta(alpha, beta): the 25%-75% and 5%-95% quantile
  # intervals and the distribution mean.
  df_est <- res_time_mod %>%
    mutate(
      lower_in  = qbeta(0.25, alpha, beta),
      upper_in  = qbeta(0.75, alpha, beta),
      lower_in2 = qbeta(0.05, alpha, beta),
      upper_in2 = qbeta(0.95, alpha, beta),
      mean      = alpha / (alpha + beta)
    )

  # Label the time points (0 = prior, 1-3 = polls), then show the prior
  # estimates ordered by party ID.
  df_est$time <- factor(df_est$time,
                        levels = 0:3,
                        labels = c("Prior", "1. Poll", "2. Poll", "3. Poll"))
  df_est %>% filter(time == "Prior") %>% arrange(pid3)
  
  # Plot Data
  # Belief intervals over time for Democrat and Republican respondents
  # (opaque bar: 25%-75%, faint bar: 5%-95%, point: mean), faceted by source
  # and treatment; Independents are excluded from this figure.
  # Fix: each aes() previously set `group` twice (group = time and
  # group = pid3). ggplot2 aborts on duplicated aesthetics, so only
  # group = pid3 is kept, matching geom_point() and the all-respondent plot.
  ggplot(filter(df_est, pid3 != "Independet Resp.")) +
    geom_linerange(aes(x = time, ymin = lower_in, ymax = upper_in,
                       col = pid3, group = pid3),
                   position = position_dodge2(0.5),
                   size = 3, alpha = 0.7) +
    geom_linerange(aes(x = time, ymin = lower_in2, ymax = upper_in2,
                       col = pid3, group = pid3),
                   position = position_dodge2(0.5),
                   size = 3, alpha = 0.4) +
    geom_point(aes(x = time, y = mean, col = pid3, group = pid3),
               position = position_dodge2(0.5), size = 3.5) +
    facet_grid(source ~ cand) +
    scale_colour_manual("", values = c("Democrat Resp." = "#0015bc",
                                       "Republican Resp." = "#e9141d")) +
    theme_light() + xlab("") +
    theme(axis.title.y = element_blank(),
          text = element_text(size = 20),
          strip.text.y = element_text(angle = 360),
          legend.position = "bottom",
          panel.grid.minor = element_blank()
    )

  ggsave("fig_beliefs_partisian_time.pdf", width = 12, height = 10)
  
  # Same belief-interval figure as above, but for all respondents including
  # Independents (black).
  # Fix: the colour key for Independents read "Independet Rep.", which matches
  # no pid3 label, so they were drawn in the scale's NA colour and left out of
  # the legend; it now matches the label "Independet Resp.".
  ggplot(df_est) +
    geom_linerange(aes(x = time, ymin = lower_in, ymax = upper_in,
                       col = pid3, group = pid3),
                   position = position_dodge2(0.5),
                   size = 3, alpha = 0.7) +
    geom_linerange(aes(x = time, ymin = lower_in2, ymax = upper_in2,
                       col = pid3, group = pid3),
                   position = position_dodge2(0.5),
                   size = 3, alpha = 0.4) +
    geom_point(aes(x = time, y = mean, col = pid3, group = pid3),
               position = position_dodge2(0.5), size = 3.5) +
    facet_grid(source ~ cand) +
    scale_colour_manual("", values = c("Democrat Resp." = "#0015bc",
                                       "Independet Resp." = "#000000",
                                       "Republican Resp." = "#e9141d")) +
    theme_light() + xlab("") +
    theme(axis.title.y = element_blank(),
          text = element_text(size = 20),
          strip.text.y = element_text(angle = 360),
          legend.position = "bottom",
          panel.grid.minor = element_blank()
    )

  ggsave("fig_beliefs_partisian_time_all.pdf", width = 12, height = 10)
  
  
# Estimate Model Dynamic Learning Model ============

  # Fit the dynamic learning model (est_mod with pid = TRUE) once per
  # source x cand x pid3 group, on rows with complete ex, lo, up, pl and pm.
  res_dyn_mod <- dat_long %>%
    drop_na(ex, lo, up, pl, pm) %>%
    group_by(source, cand, pid3) %>%
    do(est_mod(., pid = TRUE))

  # Log-likelihood per group; `par` records 3 + 3 parameters per group.
  log_lik_m2 <- res_dyn_mod %>%
    group_by(source, cand, pid3) %>%
    summarise(log_lik = sum(log_lik)) %>%
    mutate(model = "fixpriors_dynamic",
           par   = 3 + 3)

  # Drop the log-likelihood column, then add the standardized rate of
  # adaption d = rho / (rho + delta).
  res_dyn_mod <- res_dyn_mod %>%
    dplyr::select(-log_lik) %>%
    mutate(d = rho / (rho + delta))

  # table
  xtable(res_dyn_mod)
  
  
# Evaluate Learning parameters =========

  # Draw learning parameters for one group of data.
  #
  # Fits the dynamic model via est_mod(d, ret = "all", pid = TRUE), draws REP
  # raw parameter vectors from a multivariate normal centred on par_raw with
  # covariance vcov (presumably the MLE and its estimated covariance), and
  # transforms them: delta = logistic(raw[1]), rho = exp(raw[2]).
  #
  # d   : data for one group (as passed in by do()).
  # REP : number of draws.
  # Returns a data.frame with columns delta and rho, one row per draw.
  sample_learningtype <- function(d, REP = 3000) {
    est <- est_mod(d, ret = "all", pid = TRUE)
    # matrix(..., nrow = REP) keeps a REP x k matrix even for REP == 1, where
    # mvrnorm() would otherwise return a plain vector.
    draws <- matrix(mvrnorm(REP, est[[1]]$par_raw, est[[1]]$vcov), nrow = REP)
    # Only the first two raw parameters are needed (the former exp() of the
    # third parameter was computed and then discarded).
    data.frame(delta = 1 / (1 + exp(-draws[, 1])),
               rho   = exp(draws[, 2]))
  }

  # Calculate learn Type (Sample from Likelihood)
  # Restrict to the same complete-case rows as the point estimates in
  # res_dyn_mod, so the sampled uncertainty refers to the same sample.
  df_plot_learn <- dat_long %>%
    drop_na(ex, lo, up, pl, pm) %>%
    group_by(cand, source, pid3) %>%
    do(sample_learningtype(.)) %>%
    mutate(cond = interaction(cand, source))

  
  # Summarise the sampled learning parameters: for d = rho / (delta + rho),
  # delta and rho, the mean (mid) and the 2.5% / 97.5% quantiles (low / high)
  # per condition and party ID. The result is reshaped to one row per
  # parameter with columns mid / low / high, and the parameters get readable
  # labels.
  df_plot_learn_all <- df_plot_learn %>%
    mutate(d = rho / (delta + rho)) %>%
    group_by(cond, cand, source, pid3) %>%
    summarise(across(c(d, delta, rho),
                     list(mid  = ~ mean(.x),
                          low  = ~ quantile(.x, 0.025),
                          high = ~ quantile(.x, 0.975)))) %>%
    gather(var, val, -cand, -source, -cond, -pid3) %>%
    separate(var, c("est", "lvl")) %>%
    spread(lvl, val) %>%
    mutate(est = factor(est,
                        levels = c("d", "delta", "rho"),
                        labels = c("Standardized Rate of Adaption",
                                   "Prior Stickiness Rate",
                                   "Sample Scale Factor")))

  # Learning parameters (mean and 95% interval of the draws) per condition for
  # Democrat and Republican respondents, one facet per parameter; values
  # outside [0, 1] are removed by ylim().
  ggplot(filter(df_plot_learn_all, pid3 != "Independet Resp.")) +
    geom_pointrange(aes(x = cond, y = mid, ymin = low, ymax = high,
                        col = pid3, shape = pid3),
                    size = 1.1,
                    position = position_dodge2(0.4)) +
    coord_flip() +
    facet_wrap(~ est, ncol = 1) +
    theme_bw() +
    labs(x = "", y = "") +
    ylim(c(0, 1)) +
    theme(text = element_text(size = 20),
          strip.text.y = element_text(angle = 360),
          legend.position = "top",
          legend.title = element_blank(),
          panel.grid.minor = element_blank()) +
    scale_colour_manual(values = c("Democrat Resp." = "#0015bc",
                                   "Republican Resp." = "#e9141d"))

  ggsave("fig_beliefs_partisan_all.pdf", width = 10, height = 6)
  
  # Same learning-parameter figure, now for all respondents including
  # Independents (black).
  ggplot(df_plot_learn_all) +
    geom_pointrange(aes(x = cond, y = mid, ymin = low, ymax = high,
                        col = pid3, shape = pid3),
                    size = 1.1,
                    position = position_dodge2(0.4)) +
    coord_flip() +
    facet_wrap(~ est, ncol = 1) +
    theme_bw() +
    labs(x = "", y = "") +
    ylim(c(0, 1)) +
    theme(text = element_text(size = 20),
          strip.text.y = element_text(angle = 360),
          legend.position = "top",
          legend.title = element_blank(),
          panel.grid.minor = element_blank()) +
    scale_colour_manual(values = c("Democrat Resp." = "#0015bc",
                                   "Independet Resp." = "#000000",
                                   "Republican Resp." = "#e9141d"))

  ggsave("fig_beliefs_partisan_allresp.pdf", width = 10, height = 6)
  
  
# Model Fit =========


  # Numbers: rows of dat_long divided by 4, presumably one row per time point
  # (0-3) per respondent.
  cases <- dat_long %>%
    summarise(N = n() / 4)

  # Log Likelihood
  # Total log-likelihood, parameter count, AIC and BIC per model. The
  # parameter count is the sum of the per-group `par` stored in log_lik_m1 and
  # log_lik_m2, so it stays attached to the right model. Previously a
  # hard-coded c(6*4*3, 12*4*3) silently relied on summarise() returning the
  # models in alphabetical order and on a fixed number of groups.
  log_lik <- bind_rows(log_lik_m1, log_lik_m2) %>%
    group_by(model) %>%
    summarise(log_lik = sum(log_lik),
              N       = cases$N,
              par     = sum(par)) %>%
    mutate(aic = 2 * par - 2 * log_lik,
           bic = par * log(N) - 2 * log_lik)
  
  
  # Model-fit table with short model names (unknown names become NA), printed
  # as LaTeX without row names.
  table_modelfit <- mutate(
    log_lik,
    model = unname(c(separate_each_time = "independent",
                     fixpriors_dynamic  = "dynamic")[model])
  )

  print(xtable::xtable(table_modelfit), include.rownames = FALSE)
  